import numpy as np
import pandas as pd
import seaborn as sns
import xlrd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from rdkit.Chem import Crippen
from rdkit.Chem import AllChem
from rdkit import Chem
from MulInfo import *
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

# Names that calculate_correlation() reports for the spreadsheet columns it
# reads (each row from its third cell onwards), in column order.
# NOTE: this name shadows the builtin ``property`` within this module.
property = [
    'Total Molweight',
    'Molweight',
    'Monoisotopic Mass',
    'cLogP',
    'cLogS',
    'H-Acceptors',
    'H-Donors',
    'Total Surface Area',
    'Relative PSA',
    'Polar Surface Area',
    'Druglikeness',
]

# Names used to locate columns in the table kept by read_from_p() and
# read_from_z() (each row from its fourth cell onwards).
features = [
    'Molweight',
    'Monoisotopic Mass',
    'cLogP',
    'cLogS',
    'H-Acceptors',
    'H-Donors',
    'Total Surface Area',
    'Relative PSA',
    'Polar Surface Area',
    'Druglikeness',
]

# Subset of `features` to keep, and the position of each one inside `features`.
selected_features = ['Molweight', 'cLogP', 'cLogS']
selected_idx = list(map(features.index, selected_features))

# Run configuration consumed by the script entry point.
infile = 'control_MulVAE_z'  # directory holding the per-batch .npy files
prop = 'logp'                # prefix of the .npy files written by read_from_z
n = 3369                     # number of batch files to read (indices 0..n-1)

# calculate property correlations
def calculate_correlation():
    """Print Pearson and Spearman correlations for every pair of properties.

    Reads the first sheet of ``data_mols_stats.xlsx``, drops the title row and
    the first two cells of every row, and pairs the remaining columns with the
    entries of the module-level ``property`` list by position.  For each pair
    of columns the Pearson coefficient is printed first, then the Spearman
    coefficient; a coefficient whose magnitude exceeds 0.3 is framed by
    dashed lines.
    """
    book = xlrd.open_workbook('data_mols_stats.xlsx')
    first_sheet = book.sheet_by_index(0)
    rows = [first_sheet.row_values(r)[2:] for r in range(first_sheet.nrows)]
    # rows[0] carries the titles, not data
    table = np.array(rows[1:])

    def report(template, left_name, right_name, coefficient):
        # Frame the line with dashes only when the correlation is notable.
        notable = abs(coefficient) > 0.3
        if notable:
            print ('--------------------------------------------------')
        print (template % (left_name, right_name, coefficient))
        if notable:
            print ('--------------------------------------------------')

    count = len(property)
    for a in range(count):
        for b in range(a + 1, count):
            left, right = table[:,a], table[:,b]
            pearson_value, _ = pearsonr(left, right)
            spearman_value, _ = spearmanr(left, right)
            report('pearson\'s correlation between %s and %s: %f',
                   property[a], property[b], pearson_value)
            report('spearman\'s correlation between %s and %s: %f',
                   property[a], property[b], spearman_value)

def correlation_heat_map():
    """Save an annotated heat map of property correlations to ``heatmap.pdf``.

    Reads the first sheet of ``data_mols_stats.xlsx``, skips the title row and
    the first two cells of every row, removes columns 0 and 2 of what is left
    and labels the remaining nine columns with short hard-coded titles.  The
    pairwise correlation matrix from ``DataFrame.corr()`` is drawn on a
    dedicated figure so the plot cannot end up on top of axes left over from
    earlier pyplot calls, and the figure is closed once it has been saved.

    The shape of the reduced table and the DataFrame itself are printed.
    """
    workbook = xlrd.open_workbook('data_mols_stats.xlsx')
    sheet = workbook.sheet_by_index(0)
    # Row 0 holds the column titles; only the data rows are kept.
    data_mol = np.array([sheet.row_values(rowx)[2:]
                         for rowx in range(1, sheet.nrows)])
    titles = ['W', 'cLogP', 'cLogS','H-A','H-D','TSA', 'rPSA', 'PSA', 'Drug']
    # Drop columns 0 and 2 so that nine columns remain, one per title.
    # NOTE(review): presumably these are 'Total Molweight' and
    # 'Monoisotopic Mass' (the order of the module-level list) - confirm
    # against the sheet.
    data_mol = np.delete(data_mol,[0,2],axis=1)
    print (data_mol.shape)
    df = pd.DataFrame(data=data_mol, columns=titles)
    print (df)
    # Use an explicit figure/axes instead of whatever pyplot considers current.
    fig, ax = plt.subplots()
    sns.heatmap(df.corr(), annot=True, ax=ax)
    fig.tight_layout()
    fig.savefig("heatmap.pdf")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def read_from_p(infile, prop,n):
    """Print the MSE between predicted properties and spreadsheet values.

    For ``i`` in ``range(n)`` the files ``smiles_<i>.npy``,
    ``pred_prop_<i>.npy`` and ``pred_weight_<i>.npy`` are loaded from
    ``./<infile>/`` and joined along the first axis.  Every loaded SMILES is
    then looked up in the second column of the first sheet of
    ``data_mols_stats.xlsx`` and the matching row, restricted to the columns
    listed in the module-level ``selected_idx``, is collected in the same
    order as the predictions.

    Three mean squared errors are printed: ``pred_prop`` column 0 against
    selected column 1, ``pred_prop`` column 1 against selected column 2, and
    ``pred_weight`` against selected column 0.

    Args:
        infile: name of the directory (relative to the working directory)
            that holds the per-batch ``.npy`` files.
        prop: not used by this function.
        n: number of batches to read; must be at least 1.

    Raises:
        ValueError: if ``n`` is smaller than 1, or if a loaded SMILES does not
            occur in the spreadsheet.
    """
    if n < 1:
        raise ValueError('n must be at least 1, got %r' % (n,))

    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, which
    # can execute code; only read batch files from a trusted source.
    folder = './' + infile + '/'
    smiles_parts, logp_parts, weight_parts = [], [], []
    for i in range(n):
        tag = str(i) + '.npy'
        smiles_parts.append(
            np.load(folder + 'smiles_' + tag, allow_pickle=True).flatten())
        logp_parts.append(
            np.load(folder + 'pred_prop_' + tag, allow_pickle=True))
        weight_parts.append(
            np.load(folder + 'pred_weight_' + tag, allow_pickle=True))
    # Join once at the end; concatenating inside the loop re-copies the
    # growing arrays on every batch.
    smiles = np.concatenate(smiles_parts)
    pred_logp = np.concatenate(logp_parts, axis=0)
    pred_weight = np.concatenate(weight_parts, axis=0)

    data,label = [],[]
    print (pred_logp.shape)
    workbook = xlrd.open_workbook('data_mols_stats.xlsx')
    sheet = workbook.sheet_by_index(0)
    # Row 0 holds the column titles, so start at row 1.
    for rowx in range(1, sheet.nrows):
        values = sheet.row_values(rowx)
        data.append(values[3:])
        label.append(values[1])
    data = np.array(data)
    print (data.shape, selected_idx)
    data = data[:,selected_idx]
    print(data.shape,len(label))

    # Map each label to the first row it appears in (the row list.index would
    # return) so each lookup is O(1) instead of a scan of the label list.
    row_of = {}
    for rowx, name in enumerate(label):
        row_of.setdefault(name, rowx)
    try:
        rows = [row_of[smi] for smi in smiles]
    except KeyError as err:
        raise ValueError('%r is not in data_mols_stats.xlsx' % (err.args[0],))
    selected_data = data[np.asarray(rows, dtype=np.intp), :]

    # Optional dump of the aligned arrays (disabled):
    #infile = 'temp'
    #np.save('pred_prop_'+infile,pred_logp[:,0])
    #np.save('pred_weight_'+infile,pred_weight)
    #np.save(prop+'_'+infile,selected_data)

    print ('MSE','logp', mean_squared_error(pred_logp[:,0],selected_data[:,1]))
    print ('MSE','logs', mean_squared_error(pred_logp[:,1],selected_data[:,2]))
    print ('MSE','weight', mean_squared_error(pred_weight,selected_data[:,0]))

def read_from_z(infile, prop,n):
    """Collect sampled latent codes and matching spreadsheet values, then save both.

    For ``i`` in ``range(n)`` the files ``smiles_<i>.npy`` and
    ``z_sampled_<i>.npy`` are loaded from ``./<infile>/`` and joined along the
    first axis.  Every loaded SMILES is then looked up in the second column of
    the first sheet of ``data_mols_stats.xlsx`` and the matching row,
    restricted to the columns listed in the module-level ``selected_idx``, is
    collected in the same order as the latent codes.

    Two arrays are written to the working directory with ``np.save``:
    the latent codes under ``<prop>_z_<infile>`` and the selected spreadsheet
    values under ``<prop>_<infile>``.

    Args:
        infile: name of the directory (relative to the working directory)
            that holds the per-batch ``.npy`` files; also used as the suffix
            of the output file names.
        prop: prefix of the output file names.
        n: number of batches to read; must be at least 1.

    Raises:
        ValueError: if ``n`` is smaller than 1, or if a loaded SMILES does not
            occur in the spreadsheet.
    """
    if n < 1:
        raise ValueError('n must be at least 1, got %r' % (n,))

    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, which
    # can execute code; only read batch files from a trusted source.
    folder = './' + infile + '/'
    smiles_parts, z_parts = [], []
    for i in range(n):
        tag = str(i) + '.npy'
        smiles_parts.append(
            np.load(folder + 'smiles_' + tag, allow_pickle=True).flatten())
        z_parts.append(
            np.load(folder + 'z_sampled_' + tag, allow_pickle=True))
    # Join once at the end; concatenating inside the loop re-copies the
    # growing arrays on every batch.
    smiles = np.concatenate(smiles_parts)
    z_sampled = np.concatenate(z_parts, axis=0)

    data,label = [],[]
    print (z_sampled.shape)
    workbook = xlrd.open_workbook('data_mols_stats.xlsx')
    sheet = workbook.sheet_by_index(0)
    # Row 0 holds the column titles, so start at row 1.
    for rowx in range(1, sheet.nrows):
        values = sheet.row_values(rowx)
        data.append(values[3:])
        label.append(values[1])
    data = np.array(data)
    print (data.shape, selected_idx)
    data = data[:,selected_idx]
    print(data.shape,len(label))

    # Map each label to the first row it appears in (the row list.index would
    # return) so each lookup is O(1) instead of a scan of the label list.
    row_of = {}
    for rowx, name in enumerate(label):
        row_of.setdefault(name, rowx)
    try:
        rows = [row_of[smi] for smi in smiles]
    except KeyError as err:
        raise ValueError('%r is not in data_mols_stats.xlsx' % (err.args[0],))
    selected_data = data[np.asarray(rows, dtype=np.intp), :]

    print (z_sampled)
    np.save(prop+'_z_'+infile,z_sampled)
    np.save(prop+'_'+infile,selected_data)

    #print ('MSE','logp', mean_squared_error(z_sampled[:,1],selected_data[:,1]))
    #print ('MSE','logs', r2_score(pred_logp[:,1],selected_data[:,2]))
    #print ('MSE','weight', mean_squared_error(z_sampled[:,0],selected_data[:,0]))


def visualize():
    """Plot two spreadsheet columns as curves and save them as a PNG.

    Reads the first sheet of ``generated_smiles_qm9_controlVAE.xlsx``, drops
    the title row and the first two cells of every row, then plots rows
    0-99 of column 0 (legend entry 'weight') and rows 100-199 of column 2
    (legend entry 'clogP') against their position in the slice.  All ticks
    and tick labels are hidden and the figure is written to
    ``./controlVAE_z_vs_property.png``.  The shape of the full table is
    printed.
    """
    book = xlrd.open_workbook('generated_smiles_qm9_controlVAE.xlsx')
    first_sheet = book.sheet_by_index(0)
    rows = [first_sheet.row_values(r)[2:] for r in range(first_sheet.nrows)]
    # rows[0] carries the titles, not data
    data = np.array(rows[1:])
    weight_curve = data[:100,0]
    clogp_curve = data[100:200,2]
    print(data.shape)

    fig, ax = plt.subplots()
    for curve, name in ((weight_curve, 'weight'), (clogp_curve, 'clogP')):
        ax.plot(curve, label=name)
    ax.set_title('latent variable z vs property curve', fontsize=20)
    ax.set_xlabel('z value', fontsize=20)
    ax.set_ylabel('property value', fontsize=20)
    ax.legend()

    # Switch off major and minor ticks, and their labels, on both axes.
    hidden = dict(
        axis='both',
        which='both',
        left=False,
        labelleft=False,
        bottom=False,
        top=False,
        labelbottom=False,
    )
    ax.tick_params(**hidden)

    fig.tight_layout()
    fig.savefig('./controlVAE_z_vs_property.png')

if __name__ == '__main__':
    # Script entry point.  First gather the sampled latent codes and the
    # matching spreadsheet values for the configured run; read_from_z saves
    # both as .npy files in the working directory.
    read_from_z(infile,prop,n)
    # MIG_compute is not defined in this file; presumably it comes from the
    # star-import of MulInfo and consumes the files saved above - TODO confirm.
    MIG_compute(infile,prop)
